import json
import time
from typing import Any, Dict, List

import requests
from sqlalchemy.orm import Session

import crud
import database
from config import DEEPSEEK_API_KEY, DEEPSEEK_BASE_URL, DEEPSEEK_MODEL, QUESTION_GENERATION_SYSTEM_PROMPT

class LLMAPI:
    """Client for chat-completion style LLM APIs used to generate quiz questions.

    Supports two wire formats:
      * OpenAI-compatible endpoints (DeepSeek, OpenAI, ...) via /chat/completions
      * Google Gemini via the :generateContent REST endpoint

    Configuration (api_key, base_url, model) is loaded from the database's
    active LLM config when a Session is supplied, otherwise from the
    environment-variable defaults in config.py.
    """

    # Map "-latest" aliases onto the concrete model names the Gemini REST
    # endpoint expects in its URL path; unknown names pass through unchanged.
    _GEMINI_MODEL_ALIASES = {
        "gemini-1.5-flash-latest": "gemini-1.5-flash",
        "gemini-1.5-pro-latest": "gemini-1.5-pro",
        "gemini-2.5-flash-latest": "gemini-2.5-flash",
    }

    def __init__(self, db: "Session" = None):
        """Resolve the LLM configuration, preferring the database over env vars.

        Args:
            db: optional SQLAlchemy session; when given, the active LLM config
                row (if any) overrides the environment-variable defaults.
        """
        active_config = crud.get_active_llm_config(db) if db else None
        if active_config:
            self.api_key = active_config.api_key
            self.base_url = active_config.base_url
            self.model = active_config.model_name
            self.config_name = active_config.config_name
        else:
            # No DB session, or no active config row: fall back to the
            # environment-backed defaults from config.py.
            self.api_key = DEEPSEEK_API_KEY
            self.base_url = DEEPSEEK_BASE_URL
            self.model = DEEPSEEK_MODEL
            self.config_name = "DeepSeek默认配置"

        # Auth headers differ per provider, so build them after config resolution.
        self.headers = self._build_headers()

    def _build_headers(self) -> Dict[str, str]:
        """Build HTTP headers matching the configured provider's auth scheme."""
        if 'gemini' in self.model.lower():
            # Gemini authenticates with a dedicated API-key header.
            return {
                "x-goog-api-key": self.api_key,
                "Content-Type": "application/json"
            }
        # OpenAI-compatible providers (DeepSeek, OpenAI, ...) use Bearer auth.
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

    def _build_request(self, enhanced_prompt: str, max_tokens: int):
        """Build the request payload and API URL for the configured provider.

        Args:
            enhanced_prompt: the fully composed user prompt.
            max_tokens: token budget for the completion.

        Returns:
            (request_data, api_url) tuple ready for requests.post().
        """
        if 'gemini' in self.model.lower():
            # Gemini takes no separate system role here, so the system prompt
            # is prepended to the user prompt inside a single "parts" entry.
            request_data = {
                "contents": [
                    {
                        "parts": [
                            {
                                "text": f"{QUESTION_GENERATION_SYSTEM_PROMPT}\n\n{enhanced_prompt}"
                            }
                        ]
                    }
                ],
                "generationConfig": {
                    "temperature": 0.3,
                    # Give Gemini extra headroom but stay under the 8192 cap.
                    "maxOutputTokens": min(max_tokens * 2, 8192),
                    "topP": 0.8,
                    "topK": 40
                }
            }
            # Normalize "-latest" aliases to the concrete model name.
            model_name = self._GEMINI_MODEL_ALIASES.get(self.model, self.model)
            if self.base_url.endswith('/models'):
                api_url = f"{self.base_url}/{model_name}:generateContent"
            else:
                api_url = f"{self.base_url.rstrip('/')}/models/{model_name}:generateContent"
        else:
            # OpenAI-compatible format (DeepSeek, OpenAI, ...).
            request_data = {
                "model": self.model,
                "messages": [
                    {
                        "role": "system",
                        "content": QUESTION_GENERATION_SYSTEM_PROMPT
                    },
                    {
                        "role": "user",
                        "content": enhanced_prompt
                    }
                ],
                "temperature": 0.7,
                "max_tokens": max_tokens,
                "stream": False
            }
            api_url = f"{self.base_url}/chat/completions"

        return request_data, api_url

    def _extract_content(self, response_data: dict) -> str:
        """Extract the generated text from a provider-specific response body.

        Raises:
            Exception: when the response is malformed, truncated, or refused.
        """
        if 'gemini' in self.model.lower():
            # Gemini layout: candidates[0].content.parts[0].text
            if "candidates" not in response_data or len(response_data["candidates"]) == 0:
                raise Exception("Gemini API响应格式错误，未找到candidates字段")

            candidate = response_data["candidates"][0]

            # finishReason explains truncation/refusal; surface those explicitly.
            finish_reason = candidate.get("finishReason", "")
            if finish_reason == "MAX_TOKENS":
                raise Exception("Gemini API响应被截断，请减少生成内容的长度或增加maxOutputTokens")
            elif finish_reason == "SAFETY":
                raise Exception("Gemini API因安全原因拒绝生成内容")
            elif finish_reason in ["RECITATION", "OTHER"]:
                raise Exception(f"Gemini API生成失败: {finish_reason}")

            if "content" not in candidate:
                raise Exception("Gemini API响应格式错误，未找到content字段")

            content = candidate["content"]
            if "parts" not in content or len(content["parts"]) == 0:
                raise Exception("Gemini API响应格式错误，content中未找到parts字段或parts为空")

            parts = content["parts"]
            if "text" not in parts[0]:
                raise Exception("Gemini API响应格式错误，未找到text字段")

            return parts[0]["text"].strip()

        # OpenAI-compatible layout: choices[0].message.content
        if "choices" not in response_data or len(response_data["choices"]) == 0:
            raise Exception("API响应格式错误，未找到choices字段")

        return response_data["choices"][0]["message"]["content"].strip()

    def _post_with_retry(self, api_url: str, request_data: dict, max_retries: int = 3):
        """POST the request with retries and an automatic '/v1' base-path fallback.

        On the first connection/timeout failure for OpenAI-compatible providers,
        a common misconfiguration (missing '/v1' in the base URL) is probed; if
        the fallback succeeds, self.base_url is updated to the working base.

        Returns:
            The requests.Response (status is NOT checked here).

        Raises:
            Exception: after max_retries consecutive connection/timeout failures.
        """
        retry_count = 0
        while True:
            try:
                # 60s timeout: generation of several questions can be slow.
                return requests.post(
                    api_url,
                    headers=self.headers,
                    json=request_data,
                    timeout=60
                )
            except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
                retry_count += 1

                # Probe the '/v1' path exactly once, and only for non-Gemini
                # providers whose base URL does not already end in '/v1'.
                if (retry_count == 1 and
                        'gemini' not in self.model.lower() and
                        not self.base_url.rstrip('/').endswith('/v1')):
                    alt_base = self.base_url.rstrip('/') + '/v1'
                    try:
                        response = requests.post(
                            f"{alt_base}/chat/completions",
                            headers=self.headers,
                            json=request_data,
                            timeout=60
                        )
                        # Remember the base URL that actually worked.
                        self.base_url = alt_base
                        return response
                    except requests.exceptions.RequestException:
                        pass  # fallback failed too; keep retrying the original URL

                if retry_count >= max_retries:
                    if isinstance(e, requests.exceptions.Timeout):
                        raise Exception(f"API请求超时（已重试{max_retries}次），请检查网络连接或稍后重试")
                    raise Exception(f"网络连接错误，无法连接到 {self.base_url}/chat/completions（已重试{max_retries}次）。请检查：\n1. 网络连接是否正常\n2. API Base URL是否正确\n3. 是否需要代理设置\n详情：{str(e)}")

                # Exponential backoff: 2s, 4s, 8s.
                time.sleep(2 ** retry_count)

    def _check_response(self, response) -> None:
        """Raise a descriptive Exception for any non-200 HTTP response."""
        if response.status_code == 200:
            return

        error_msg = f"API请求失败，状态码: {response.status_code}"
        if response.status_code == 401:
            error_msg += "\n错误原因：API密钥无效或已过期，请检查配置中的API Key是否正确"
        elif response.status_code == 403:
            error_msg += "\n错误原因：API访问被拒绝，请检查API密钥权限或账户余额"
        elif response.status_code == 429:
            error_msg += "\n错误原因：API请求频率超限，请稍后重试"
        elif response.status_code == 500:
            error_msg += "\n错误原因：API服务器内部错误，请稍后重试"
        else:
            error_msg += f"\n响应内容: {response.text[:200]}..."
        raise Exception(error_msg)

    def _parse_questions(self, generated_content: str) -> List[Dict[str, Any]]:
        """Parse model output as a JSON array of questions, with a salvage path.

        If the content is not directly valid JSON (models sometimes wrap the
        array in prose), the first '['..']' span is extracted and re-parsed;
        a salvaged array is returned without per-field validation.
        """
        try:
            questions = json.loads(generated_content)
        except json.JSONDecodeError as e:
            start_idx = generated_content.find('[')
            end_idx = generated_content.rfind(']') + 1
            if start_idx != -1 and end_idx != 0:
                try:
                    salvaged = json.loads(generated_content[start_idx:end_idx])
                except json.JSONDecodeError:
                    raise Exception(f"生成的内容不是有效的JSON格式: {generated_content[:200]}...")
                if isinstance(salvaged, list) and len(salvaged) > 0:
                    return salvaged
            raise Exception(f"无法解析生成的JSON数据: {str(e)}")

        return self._validate_questions(questions)

    def _validate_questions(self, questions) -> List[Dict[str, Any]]:
        """Validate required fields/values and fill defaults for optional fields.

        Raises:
            Exception: on any structural or value error, identifying the
                offending question by 1-based index.
        """
        if not isinstance(questions, list):
            raise Exception("生成的数据不是数组格式")

        required_fields = ["question_content", "option_a", "option_b", "option_c", "option_d", "answer", "knowledge_point"]
        # Optional fields get these defaults when missing.
        optional_defaults = {
            "explanation": "详细解析待补充",
            "wrong_analysis": "错误选项分析待补充",
            "difficulty_level": "medium",
            "topic_category": "信息技术基础"
        }

        for i, question in enumerate(questions):
            if not isinstance(question, dict):
                raise Exception(f"第{i+1}道题目不是对象格式")

            for field in required_fields:
                if field not in question:
                    raise Exception(f"第{i+1}道题目缺少必需字段: {field}")
                if not question[field] or not str(question[field]).strip():
                    raise Exception(f"第{i+1}道题目的{field}字段不能为空")

            if question["answer"] not in ["A", "B", "C", "D"]:
                raise Exception(f"第{i+1}道题目的答案格式错误，应为A、B、C或D")

            # Difficulty is optional, but must be a known level when present.
            if "difficulty_level" in question and question["difficulty_level"]:
                if question["difficulty_level"] not in ["easy", "medium", "hard"]:
                    raise Exception(f"第{i+1}道题目的难度等级错误，应为easy、medium或hard")

            for field, default in optional_defaults.items():
                question.setdefault(field, default)

        return questions

    def generate_questions(self, user_prompt: str, count: int = 5, difficulty: str = None, topic: str = None) -> List[Dict[str, Any]]:
        """Generate multiple-choice questions from a user prompt.

        Args:
            user_prompt: free-form description of the questions wanted.
            count: number of questions to generate (default 5).
            difficulty: optional difficulty level (easy/medium/hard).
            topic: optional topic/category to focus on.

        Returns:
            List of question dicts (validated, with optional-field defaults filled).

        Raises:
            Exception: on API failure or when the response cannot be parsed.
        """
        try:
            # Fail fast on a missing key; otherwise it only surfaces as a 401.
            if not self.api_key or not self.api_key.strip():
                raise Exception(f"未配置API密钥，请在{self.config_name}配置中设置API Key")

            enhanced_prompt = self._build_enhanced_prompt(user_prompt, count, difficulty, topic)

            # Scale the token budget with the requested question count.
            max_tokens = min(4000, max(1000, count * 400))

            request_data, api_url = self._build_request(enhanced_prompt, max_tokens)

            # Debug logging (prints kept to match the project's existing style).
            print(f"[DEBUG] 使用模型: {self.model}")
            print(f"[DEBUG] API URL: {api_url}")
            print(f"[DEBUG] 请求头: {self.headers}")
            print(f"[DEBUG] 配置名称: {self.config_name}")

            response = self._post_with_retry(api_url, request_data)
            self._check_response(response)

            generated_content = self._extract_content(response.json())
            return self._parse_questions(generated_content)

        except requests.exceptions.Timeout:
            raise Exception("API请求超时，请稍后重试")
        except requests.exceptions.ConnectionError as e:
            raise Exception(f"网络连接错误，请检查网络连接（{self.base_url}）；详情：{str(e)}")
        except requests.exceptions.RequestException as e:
            raise Exception(f"请求错误: {str(e)}")
        except Exception as e:
            raise Exception(f"生成题目失败: {str(e)}")

    def _build_enhanced_prompt(self, user_prompt: str, count: int, difficulty: str, topic: str) -> str:
        """Compose the final user prompt from count/difficulty/topic plus the raw request.

        Args:
            user_prompt: the caller's original prompt.
            count: number of questions requested.
            difficulty: optional difficulty level (easy/medium/hard).
            topic: optional topic/category.

        Returns:
            Newline-joined prompt text.
        """
        enhanced_parts = [f"请生成{count}道高中信息技术选择题。"]

        if difficulty:
            difficulty_map = {
                "easy": "简单（基础概念理解）",
                "medium": "中等（概念应用和分析）",
                "hard": "困难（综合应用和深度思考）"
            }
            # Unknown difficulty values are silently ignored.
            if difficulty in difficulty_map:
                enhanced_parts.append(f"难度等级：{difficulty_map[difficulty]}")

        if topic:
            enhanced_parts.append(f"重点关注主题：{topic}")

        enhanced_parts.append(f"具体要求：{user_prompt}")

        # Generic quality requirements appended to every prompt.
        enhanced_parts.append("请确保：")
        enhanced_parts.append("1. 题目之间有适当的多样性，避免重复")
        enhanced_parts.append("2. 选项设计合理，有一定迷惑性但不误导")
        enhanced_parts.append("3. 知识点覆盖全面，符合教学目标")
        enhanced_parts.append("4. 解析详细，有助于学生理解")

        return "\n".join(enhanced_parts)

    def generate_questions_batch(self, prompts: List[str], count_per_prompt: int = 3) -> List[Dict[str, Any]]:
        """Generate questions for each prompt, skipping prompts that fail.

        Args:
            prompts: list of generation prompts.
            count_per_prompt: questions to generate per prompt.

        Returns:
            Flat list of all successfully generated questions.
        """
        all_questions = []
        for i, prompt in enumerate(prompts):
            try:
                all_questions.extend(self.generate_questions(prompt, count_per_prompt))
            except Exception as e:
                # Best-effort batch: log and continue with the remaining prompts.
                print(f"批量生成第{i+1}个提示失败: {str(e)}")
                continue
        return all_questions

# Module-level instance kept for backward compatibility: callers without a
# database session share this env-var-configured client.
llm_api = LLMAPI()

def generate_questions_from_prompt(prompt: str, db: Session = None, count: int = 5, difficulty: str = None, topic: str = None) -> List[Dict[str, Any]]:
    """Convenience wrapper: generate questions for a single prompt.

    Args:
        prompt: generation prompt text.
        db: optional database session; when given, the DB's active LLM config
            is used, otherwise the env-var-configured global instance.
        count: number of questions to generate (default 5).
        difficulty: optional difficulty level (easy/medium/hard).
        topic: optional topic/category.

    Returns:
        List of generated question dicts.
    """
    # Pick the client once, then delegate.
    client = LLMAPI(db) if db else llm_api
    return client.generate_questions(prompt, count, difficulty, topic)

def generate_questions_batch_from_prompts(prompts: List[str], db: Session = None, count_per_prompt: int = 3) -> List[Dict[str, Any]]:
    """Convenience wrapper: generate questions for several prompts.

    Args:
        prompts: list of generation prompts.
        db: optional database session; when given, the DB's active LLM config
            is used, otherwise the env-var-configured global instance.
        count_per_prompt: questions to generate per prompt.

    Returns:
        Flat list of all generated questions.
    """
    # Pick the client once, then delegate.
    client = LLMAPI(db) if db else llm_api
    return client.generate_questions_batch(prompts, count_per_prompt)